#!/usr/bin/env python3


from __future__ import annotations
import functools
from copy import deepcopy
from pathlib import Path
from typing import List

import os
import numpy as np
import torch
import wandb
from rpi import logger
from rpi.helpers import set_random_seed, to_torch
from rpi.helpers.data import flatten
from rpi.helpers.env import rollout_single_ep
from rpi.nn.empirical_normalization import EmpiricalNormalization
from rpi.policies import (
    GaussianHeadWithStateIndependentCovariance,
    SoftmaxCategoricalHead,
)
from rpi.scripts.sweep.default_args import Args
from rpi.scripts.train import Factory, get_expert
from . import NewStateDetector, extract_states, simple_make_env


def annotate(frames: List[np.ndarray], rewards: List[float], note: str = ''):
    """Overlay a note, a step counter and the cumulative reward on each frame.

    Frames and rewards are paired with ``zip``, so only as many frames as there
    are rewards (and vice versa) are processed. Each processed entry of
    ``frames`` is replaced by the value returned from ``cv2.putText``; the same
    list object is returned.

    Args:
        frames: Images to draw on, one per step.
        rewards: Per-step rewards, accumulated into the ``cum_rew`` overlay.
        note: Free-form text drawn on the first text row of every frame.

    Returns:
        The ``frames`` list that was passed in.
    """
    import cv2

    # Drawing options shared by all three text rows.
    text_style = dict(fontFace=3, fontScale=.4, color=(0, 255, 0), thickness=1)
    total_frames = len(frames)

    cum_rew = 0
    for idx, (image, rew) in enumerate(zip(frames, rewards)):
        cum_rew += rew
        # (text, origin) pairs, drawn top to bottom.
        overlay_rows = (
            (note, (0, 20)),
            (f"step: {idx + 1} / {total_frames}", (0, 40)),
            (f"cum_rew: {cum_rew}", (0, 60)),
        )
        for text, origin in overlay_rows:
            image = cv2.putText(image, text, org=origin, **text_style)
        frames[idx] = image
    return frames


def main(
    env_name: str,
    load_steps: List[int],
    save_dir,
    seed: int,
    num_episodes: int = 100,
    max_episode_len: int = 1000,
    deterministic_experts: bool = False,
):
    """Roll out pretrained experts, cache the episodes on disk and log them to wandb.

    For every step in ``load_steps`` the episodes are read from
    ``save_dir / env_id / deterministic-<flag> / expert-<step>.pt`` if that file
    exists. Otherwise the expert checkpoint
    ``Args.experts_dir / env_id / step_<step>.pt`` is loaded, ``num_episodes``
    episodes are rolled out in a freshly created environment (after resetting
    the random seed) and the result is saved to that file. Return statistics
    and an annotated video of the first episode are then logged with
    ``wandb.log``.

    The environment and the expert are only created when a rollout is actually
    needed; a cache hit touches neither.

    Args:
        env_name: Environment name passed to ``simple_make_env``.
        load_steps: Expert checkpoint steps to process.
        save_dir: Root directory of the rollout cache (``str`` or ``Path``).
        seed: Seed passed to ``simple_make_env`` and ``set_random_seed``.
        num_episodes: Number of episodes to roll out per expert.
        max_episode_len: Passed to ``rollout_single_ep`` for every episode.
        deterministic_experts: Forwarded to ``expert.act`` as ``mode`` and used
            in the cache path.

    Raises:
        ValueError: If a rollout is required and ``num_episodes`` is below 1.
    """
    # Accept plain strings as well as Path objects.
    save_dir = Path(save_dir)

    # Prepare make_env function
    make_env, state_dim, act_dim, env_id = simple_make_env(env_name, default_seed=seed)

    # Policy-head template; every expert gets its own deep copy.
    policy_head = GaussianHeadWithStateIndependentCovariance(
        action_size=act_dim,
        var_type="diagonal",
        var_func=lambda x: torch.exp(2 * x),  # Parameterize log std
        var_param_init=0,  # log std = 0 => std = 1
    )

    # Load each expert's episodes from file if they exist, otherwise generate them
    for load_step in load_steps:
        fpath = (
            save_dir
            / env_id
            / f'deterministic-{deterministic_experts}'
            / f"expert-{load_step:06d}.pt"
        )
        if fpath.exists():
            logger.info(f"Loading episodes from {fpath}...")
            # NOTE(review): the file holds a pickled dict of non-tensor objects
            # written by this function. Newer PyTorch releases default
            # `weights_only=True`, which may reject it -- confirm against the
            # installed torch version. Only load files from a trusted location.
            obj = torch.load(fpath)
            episodes = obj["episodes"]
            ep_returns = obj["ep_returns"]
        else:
            # Refuse to write an empty cache file: it would be picked up by
            # every later run and break the video logging below.
            if num_episodes < 1:
                raise ValueError(
                    f"num_episodes must be at least 1, got {num_episodes}"
                )

            # Re-instantiate environment and reset random seed
            set_random_seed(seed)
            env = make_env()

            # Load expert
            expert = get_expert(
                state_dim,
                act_dim,
                deepcopy(policy_head),
                Path(Args.experts_dir) / env_id / f"step_{load_step:06d}.pt",
                obs_normalizer=None,
            )

            # Roll out `num_episodes` episodes per expert and save the trajectories
            logger.info(f"Rolling out expert {load_step}...")
            # NOTE(review): `mode` receives the boolean flag as-is -- confirm
            # that `expert.act` expects a bool here.
            act_fn = functools.partial(expert.act, mode=deterministic_experts)
            episodes = [
                rollout_single_ep(
                    env,
                    act_fn,
                    max_episode_len,
                    save_video=(ep == 0),  # frames are requested for the first episode only
                )
                for ep in range(num_episodes)
            ]
            # Get stats
            ep_returns = np.array(
                [sum(tr["reward"] for tr in transitions) for transitions in episodes]
            )

            logger.info(f"Saving the episodes to {fpath}...")
            fpath.parent.mkdir(mode=0o775, parents=True, exist_ok=True)
            torch.save({"episodes": episodes, "ep_returns": ep_returns}, fpath)

        wandb.log(
            {
                "expert-step": load_step,
                "ep_returns-mean": ep_returns.mean(),
                "ep_returns-stddev": ep_returns.std(),
                "ep_returns-median": np.median(ep_returns),
                "ep_returns-hist": wandb.Histogram(ep_returns.tolist()),
            }
        )

        # Visualize the video on wandb
        first_episode = episodes[0]
        frames = [tr['frame'] for tr in first_episode]
        rewards = [tr['reward'] for tr in first_episode]
        frames = annotate(frames, rewards=rewards, note=f'expert-{load_step:06d}')
        # The transpose moves the last axis of the stacked frames to position 1
        # before handing the array to wandb.Video.
        video = wandb.Video(np.asarray(frames).transpose(0, 3, 1, 2), fps=20, format='mp4')
        wandb.log({'video': video, 'expert-step': load_step})


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, start a wandb run and process
    # every requested expert checkpoint with `main`.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("envname", help="Environment name (e.g., dmc:Cheetah-run-v1)")
    # Required: without it `load_steps` would be None and `main` would fail
    # when iterating over it.
    parser.add_argument(
        "--load-steps",
        nargs='+',
        type=int,
        required=True,
        help="Expert checkpoint steps to roll out (one or more integers)",
    )
    parser.add_argument(
        "--mode",
        action='store_true',
        help="Passed to main() as deterministic_experts",
    )
    pyargs = parser.parse_args()

    print("pyargs", pyargs)

    save_dir = Path("/expert-rollouts")

    # Default to the first GPU unless the caller already restricted visibility.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    wandb.login()
    wandb.init(
        # Set the project where this run will be logged
        project="alops-experts",
        config=vars(Args),
    )

    main(
        pyargs.envname,
        pyargs.load_steps,
        save_dir,
        seed=0,
        deterministic_experts=pyargs.mode,
    )
